package dai.samples.mybatis.interceptor;

import dai.samples.mybatis.entity.Page;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.ExecutorException;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.property.PropertyTokenizer;
import org.apache.ibatis.scripting.xmltags.ForEachSqlNode;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.transaction.Transaction;
import org.apache.ibatis.type.TypeHandler;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.stereotype.Component;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;

/**
 *
 * @author daify
 * @date 2019-08-01
 **/
@Component
@Intercepts({
        // 拦截Executor的query的操作
        @Signature(
                type = Executor.class,
                method = "query",
                args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
        )
})
@Slf4j
public class PagePlugin implements Interceptor {
    
    @Override 
    public Object intercept(Invocation invocation) throws Throwable {
        Executor target = (Executor) invocation.getTarget();
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        MapperMethod.ParamMap<Object> mapParamMap = (MapperMethod.ParamMap<Object>) args[1];
        // 获取NAMESPACE和方法
        String[] nameSpaceId = ms.getId().split("\\.");
        if (!ArrayUtils.isEmpty(nameSpaceId) && nameSpaceId[nameSpaceId.length - 1].endsWith("MyPage")) {
            // 假装之前有类型判断，太晚了不想写了
            Page page = (Page) mapParamMap.get("page");
            BoundSql boundSql = ms.getBoundSql(mapParamMap);
            int count = getCount(target,ms,mapParamMap,boundSql);
            args[2] = new RowBounds(page.getStart(), page.getPageSize());
            page.setPageTotal(count);
        }
        // 主要的业务
        return invocation.proceed();
    }

    /**
     * 计算总数
     * @param target
     * @param ms
     * @param params
     * @param boundSql
     * @return
     * @throws SQLException
     */
    private int getCount(Executor target,MappedStatement ms,MapperMethod.ParamMap<Object> params,BoundSql boundSql) throws SQLException {
        Transaction transaction = target.getTransaction();
        Connection connection = transaction.getConnection();

        String sql = boundSql.getSql().trim();
        // 查询数量
        String countSql = "select count(0) from ( "+ sql + " )";
        PreparedStatement cStatement = null;
        ResultSet rs = null;
        int totalCount=0;
        try {
            // 执行总数的SQL
            cStatement = connection.prepareStatement(countSql);
            // 设置SQL
            setParameters(cStatement,ms,boundSql,boundSql.getParameterObject());
            // 执行查询
            rs = cStatement.executeQuery();
            if (rs.next()) {
                // 获得结果
                totalCount = rs.getInt(1);
            }
        } catch (Exception e) {
            log.info(e.getMessage());
        } finally {
            try {
                rs.close();
                cStatement.close();
            } catch (Exception e) {
                log.error("SQLException", e);
            }
        }
        return totalCount;
    }


    private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
                               Object parameterObject) throws SQLException {
        ErrorContext.instance().activity("setting parameters")
                .object(mappedStatement.getParameterMap().getId());
        // 获得参数映射
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings != null) {
            Configuration configuration = mappedStatement.getConfiguration();
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            // 元数据参数
            MetaObject metaObject = parameterObject == null ? null : configuration.newMetaObject(parameterObject);
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    // 获得参数名称
                    String propertyName = parameterMapping.getProperty();
                    PropertyTokenizer prop = new PropertyTokenizer(propertyName);
                    if (parameterObject == null) {
                        value = null;
                    } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                        value = parameterObject;
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (propertyName.startsWith(ForEachSqlNode.ITEM_PREFIX) && boundSql.hasAdditionalParameter(prop.getName())) {
                        value = boundSql.getAdditionalParameter(prop.getName());
                        if (value != null) {
                            value = configuration.newMetaObject(value)
                                    .getValue(propertyName.substring(prop.getName().length()));
                        }
                    } else {
                        value = metaObject == null ? null : metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    if (typeHandler == null) {
                        throw new ExecutorException(
                                "There was no TypeHandler found for parameter " + propertyName + " of statement " + mappedStatement.getId());
                    }
                    typeHandler.setParameter(ps, i + 1, value, parameterMapping.getJdbcType());
                }
            }
        }
    }


    /**
     * 将这个类作为包装类假如拦截链
     * @param o
     * @return
     */
    @Override 
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    /**
     * 设置参数
     * @param properties
     * @return
     */
    @Override 
    public void setProperties(Properties properties) {

    }
}
